Allow for fast accumulation selection for FP8 GEMM #3416
Conversation
Thank you for the heads up! Also cc @GleasonK
flax/linen/fp8_ops.py
Outdated
""" | ||
FP8 helper to manage the FP8 meta | ||
""" | ||
FWD_DTYPE: DType = jnp.float8_e4m3fn |
I didn't observe the usage of FWD/BWD_DTYPE in this commit. Perhaps we should limit this commit to handling accumulation datatypes exclusively. For simplicity, I suggest hardcoding the accumulation datatypes directly within the dot general custom gradient functions as we have already done for the input/output dtypes.
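A minimal sketch of one way that suggestion could look (hypothetical helper name, not the committed code): the accumulation behaviour is hardcoded inside the custom-gradient dot_general wrapper instead of being threaded through module-level FWD_DTYPE/BWD_DTYPE constants.

```python
from functools import partial

from jax import custom_jvp, lax


@partial(custom_jvp, nondiff_argnums=(2,))
def _fp8_dot_general(lhs, rhs, dimension_numbers):
  # Forward pass: DEFAULT precision on both operands lets cublasLt select
  # fast FP8 accumulation; no configurable dtype constants are involved.
  return lax.dot_general(lhs, rhs, dimension_numbers,
                         precision=lax.Precision.DEFAULT)
```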
Force-pushed from fed6f32 to ac9e8d2.
Force-pushed from 4c62cc9 to 1ee5f34.
Codecov Report
@@            Coverage Diff             @@
##             main    #3416      +/-   ##
==========================================
+ Coverage   83.50%   83.63%   +0.12%
==========================================
  Files          56       56
  Lines        6725     6802      +77
==========================================
+ Hits         5616     5689      +73
- Misses       1109     1113       +4
Imported from GitHub PR openxla/xla#6599. The FP8 cublasLt matmul uses fast accumulation when both operands' precision is DEFAULT; otherwise it falls back to high-precision accumulation. Addresses issue openxla/xla#6168. This PR is closely related to Flax PR #3416 (google/flax#3416).

Copybara import of the project:
-- a4140da8ca08cd2d4796a7b8f032827867a361bc by shuw <shuw@nvidia.com>: Add FP8 fast accumulation support for cublasLt.
-- 96845683cc4b1e7b947bc919fbf97d8865abeac9 by shuw <shuw@nvidia.com>: Improve based on review #1
-- e906d7620780d2cf1fe8433c933648dcb98dc61d by shuw <shuw@nvidia.com>: Improve based on review #2

Merging this change closes #6599

PiperOrigin-RevId: 578948593
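At the JAX level, the behaviour described in that commit message can be exercised roughly as follows (a sketch; it assumes a Hopper-class GPU and an XLA build that includes this change, and the shapes and dtypes are purely illustrative):

```python
import jax.numpy as jnp
from jax import lax

a = jnp.ones((16, 32), dtype=jnp.float8_e4m3fn)
b = jnp.ones((32, 64), dtype=jnp.float8_e4m3fn)
dims = (((1,), (0,)), ((), ()))  # a plain matmul contraction

# Both operands at DEFAULT precision -> cublasLt fast FP8 accumulation.
fast = lax.dot_general(a, b, dims,
                       precision=lax.Precision.DEFAULT,
                       preferred_element_type=jnp.float32)

# Requesting higher precision -> fall back to high-precision accumulation.
slow = lax.dot_general(a, b, dims,
                       precision=lax.Precision.HIGHEST,
                       preferred_element_type=jnp.float32)
```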
@wenscarl - sorry, I was about to merge but there's a fresh merge conflict - would you mind resolving it so I can merge?
flax/linen/fp8_ops.py
Outdated
@@ -123,6 +123,24 @@ def out_qdq_bwd(compute_dtype, res, g):
out_qdq.defvjp(out_qdq_fwd, out_qdq_bwd)


@partial(custom_jvp, nondiff_argnums=(2,))
def dot_general_with_precision(lhs, rhs, dimension_numbers):
Can you make this have the same signature as lax.dot_general() so that I can inject it into jnp.einsum?
@partial(custom_jvp, nondiff_argnums=(2, 3, 4))
def dot_general_with_precision(lhs, rhs, dimension_numbers, precision=None, preferred_element_type=None):
  return lax.dot_general(lhs, rhs, dimension_numbers, precision=lax.Precision.DEFAULT)
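For context on why the matching signature matters: current JAX releases let jnp.einsum take an alternative dot_general implementation via its (semi-private) _dot_general argument. A usage sketch, assuming the suggested signature above is in scope and with purely illustrative shapes:

```python
import jax.numpy as jnp

x = jnp.ones((8, 16))
kernel = jnp.ones((16, 32))

# With a lax.dot_general-compatible signature, the wrapper can be injected
# into einsum so the contraction runs through dot_general_with_precision.
y = jnp.einsum('bi,io->bo', x, kernel,
               _dot_general=dot_general_with_precision)
```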
flax/linen/fp8_ops.py
Outdated
    precision=lax.Precision.DEFAULT)


@dot_general_with_precision.defjvp
def dot_general_with_precision_jvp(dimension_numbers, primals, tangents):
I think this also needs to be changed accordingly, like:
@dot_general_with_precision.defjvp
def dot_general_with_precision_jvp(dimension_numbers, precision, preferred_element_type, primals, tangents):
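A sketch of what the full rule might look like under that extended signature (not the committed code; keeping the tangent computation at HIGHEST precision is an assumption, on the theory that the gradient path should not use fast FP8 accumulation):

```python
from functools import partial

from jax import custom_jvp, lax


@partial(custom_jvp, nondiff_argnums=(2, 3, 4))
def dot_general_with_precision(lhs, rhs, dimension_numbers,
                               precision=None, preferred_element_type=None):
  # The forward pass always requests DEFAULT precision (fast accumulation).
  return lax.dot_general(lhs, rhs, dimension_numbers,
                         precision=lax.Precision.DEFAULT)


@dot_general_with_precision.defjvp
def dot_general_with_precision_jvp(dimension_numbers, precision,
                                   preferred_element_type, primals, tangents):
  lhs, rhs = primals
  lhs_dot, rhs_dot = tangents

  # Primal output reuses the fast-accumulation forward pass.
  out = dot_general_with_precision(lhs, rhs, dimension_numbers,
                                   precision, preferred_element_type)

  # The tangent is linear in (lhs_dot, rhs_dot); accumulate it at HIGHEST
  # precision so the differentiation path stays in high precision.
  tangent_out = (
      lax.dot_general(lhs_dot, rhs, dimension_numbers,
                      precision=lax.Precision.HIGHEST)
      + lax.dot_general(lhs, rhs_dot, dimension_numbers,
                        precision=lax.Precision.HIGHEST))
  return out, tangent_out
```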
As for #3416 (comment), can you also rebase your change against the latest? I think upstream has done some format cleanup. @wenscarl
Force-pushed from d307996 to 2176aac.
Force-pushed from 2176aac to 786a111.
@levskaya Can you take a look and see whether the merge can be triggered?
Sorry, could you re-push? There was a stupid warning being triggered that I've now silenced. But beyond that I'm seeing this error in internal CI for
@levskaya Could you please review this once more? We have resolved the failed test, which was caused by the accidental deletion of a decorator in a previous commit. All tests are now passing.
@kaixih - should be merged now! Sorry for the delay!
In this PR, issue #6168 is addressed by introducing a custom gradient (forward-mode autodiff) for the FP8 dot_general. This PR is closely related to XLA PR #6599. @reedwm @kaixih @burmako
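As a usage sketch (hypothetical shapes; assumes the dot_general_with_precision wrapper discussed in the review above is in scope), both forward- and reverse-mode differentiation pick up the custom rule, so the forward matmul can use fast accumulation while the gradient path does not:

```python
import jax
import jax.numpy as jnp

x = jnp.ones((8, 16))
k = jnp.ones((16, 32))
dims = (((1,), (0,)), ((), ()))


def loss(x, k):
  # Forward contraction goes through the custom wrapper.
  return dot_general_with_precision(x, k, dims).sum()


# Reverse mode works because the JVP rule is linear in the tangents.
grads = jax.grad(loss)(x, k)
```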